{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "sys.version_info(major=3, minor=6, micro=10, releaselevel='final', serial=0)\n",
      "matplotlib 3.1.2\n",
      "numpy 1.18.1\n",
      "pandas 0.25.3\n",
      "sklearn 0.22.1\n",
      "tensorflow 2.0.0\n",
      "tensorflow_core.keras 2.2.4-tf\n"
     ]
    }
   ],
   "source": [
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "# Report the TensorFlow and interpreter versions, then one line per library.\n",
    "print(tf.__version__, sys.version_info, sep='\\n')\n",
    "for module in (mpl, np, pd, sklearn, tf, keras):\n",
    "    print(f'{module.__name__} {module.__version__}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Then what?\n",
      "他来到了网易杭研大厦？\n"
     ]
    }
   ],
   "source": [
    "en_cmn_file_path = 'data/cmn_en/cmn_proc.txt'\n",
    "\n",
    "# Strip accents from Unicode text: decompose each character (NFD) and drop the combining marks\n",
    "import unicodedata\n",
    "def unicode_to_ascii(s):\n",
    "    \"\"\"Return `s` with accents removed.\n",
    "\n",
    "    The string is NFD-normalized and every character whose Unicode category is\n",
    "    'Mn' (non-spacing mark) is dropped; all other characters of the normalized\n",
    "    string are kept as they are.\n",
    "    \"\"\"\n",
    "    decomposed = unicodedata.normalize('NFD', s)\n",
    "    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']\n",
    "    return ''.join(kept)\n",
    "\n",
    "# Smoke test: one English and one Chinese sample (both are reused by later cells).\n",
    "en_sentence, cmn_sentence = 'Then what?', '他来到了网易杭研大厦？'\n",
    "\n",
    "print(*map(unicode_to_ascii, (en_sentence, cmn_sentence)), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<start> then what ? <end>\n",
      "<start> 他来到了网易杭研大厦 ？ <end>\n"
     ]
    }
   ],
   "source": [
    "# 西方语言常用：标点符号和字母分开\n",
    "import re\n",
    "def preprocess_sentence(s):\n",
    "    \"\"\"Normalize a sentence and wrap it in '<start>' / '<end>' markers.\n",
    "\n",
    "    The text is lower-cased, stripped and passed through unicode_to_ascii. The\n",
    "    listed Western and Chinese punctuation marks are then surrounded by spaces,\n",
    "    and runs of spaces are squeezed into a single space.\n",
    "    \"\"\"\n",
    "    normalized = unicode_to_ascii(s.lower().strip())\n",
    "\n",
    "    # Substitutions run in order: isolate punctuation first, then squeeze the gaps.\n",
    "    # NOTE: the character class of the second pattern holds '\"' as well as ' ',\n",
    "    # so double quotes are replaced by a space too.\n",
    "    # A further filter that blanks everything except ASCII letters and ?.!,¿ is\n",
    "    # deliberately not applied here: it would also blank Chinese characters.\n",
    "    for pattern, replacement in (\n",
    "        (r'([?.!,。，！？‘’“”])', r' \\1 '),\n",
    "        (r'[\" \"]+', ' '),\n",
    "    ):\n",
    "        normalized = re.sub(pattern, replacement, normalized)\n",
    "\n",
    "    # Trim the edges and add the sequence delimiters.\n",
    "    return '<start> ' + normalized.strip() + ' <end>'\n",
    "\n",
    "# Smoke test on the English and Chinese sample sentences.\n",
    "print('\\n'.join(preprocess_sentence(text) for text in (en_sentence, cmn_sentence)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<start> 他 来到 了 网易 杭研 大厦 ？ <end>\n"
     ]
    }
   ],
   "source": [
    "import jieba\n",
    "\n",
    "# Segment the Chinese sample with jieba.cut_for_search and space-join the pieces,\n",
    "# so that preprocess_sentence receives one space-separated token per word.\n",
    "seg_list = ' '.join(jieba.cut_for_search(cmn_sentence))\n",
    "print(preprocess_sentence(seg_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<start> if a person has not had a chance to acquire his target language by the time he's an adult , he's unlikely to be able to reach native speaker level in that language . <end>\n",
      "<start> 如果 一個 人 在 成人 前 沒 有 機會習 得 目標 語言 ， 他 對 該 語言 的 認識 達 到 母語者 程度 的 機會 是 相當 小 的 。 <end>\n"
     ]
    }
   ],
   "source": [
    "import jieba\n",
    "# 解析文件\n",
    "def parse_data(filename):\n",
    "    \"\"\"Read a tab-separated English/Chinese corpus and preprocess both sides.\n",
    "\n",
    "    Every line of `filename` must hold exactly two tab-separated fields: the\n",
    "    English sentence, then the Chinese sentence. The Chinese side is segmented\n",
    "    with jieba.cut_for_search and space-joined before preprocessing.\n",
    "\n",
    "    Returns a zip object that, when unpacked, yields two tuples: all\n",
    "    preprocessed English sentences, then all preprocessed Chinese sentences.\n",
    "    \"\"\"\n",
    "    # Use a context manager so the file handle is closed as soon as the text is\n",
    "    # read; a bare open(...).read() leaves closing to the garbage collector.\n",
    "    with open(filename, encoding='utf-8') as corpus_file:\n",
    "        lines = corpus_file.read().strip().split('\\n')\n",
    "    sentence_pairs = [line.split('\\t') for line in lines]\n",
    "    preprocess_sentence_pairs = [\n",
    "        (preprocess_sentence(en), preprocess_sentence(' '.join(jieba.cut_for_search(cmn))))\n",
    "        for en, cmn in sentence_pairs]\n",
    "    # zip(*pairs) transposes the list of (en, cmn) pairs into two parallel sequences.\n",
    "    return zip(*preprocess_sentence_pairs)\n",
    "\n",
    "# Load the corpus and show the last preprocessed pair as a sanity check.\n",
    "en_dataset, cmn_dataset = parse_data(en_cmn_file_path)\n",
    "print(en_dataset[-1], cmn_dataset[-1], sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33 36\n"
     ]
    }
   ],
   "source": [
    "def tokenizer(lang):\n",
    "    \"\"\"Fit a Keras Tokenizer on `lang` and return (padded_ids, fitted_tokenizer).\n",
    "\n",
    "    `lang` is a collection of sentences whose tokens are separated by spaces.\n",
    "    No characters are filtered out and the vocabulary size is not capped.\n",
    "    \"\"\"\n",
    "    lang_tokenizer = keras.preprocessing.text.Tokenizer(\n",
    "        num_words=None, filters='', split=' ')\n",
    "    # Count word frequencies and build the vocabulary.\n",
    "    lang_tokenizer.fit_on_texts(lang)\n",
    "    # Turn each sentence into a list of word ids ...\n",
    "    id_sequences = lang_tokenizer.texts_to_sequences(lang)\n",
    "    # ... and pad at the end of each row so all rows share one length.\n",
    "    padded_ids = keras.preprocessing.sequence.pad_sequences(id_sequences, padding='post')\n",
    "    return padded_ids, lang_tokenizer\n",
    "\n",
    "# Chinese is the model input, English is the model output.\n",
    "(input_tensor, input_tokenizer), (output_tensor, output_tokenizer) = map(\n",
    "    tokenizer, (cmn_dataset, en_dataset))\n",
    "\n",
    "def max_length(tensor):\n",
    "    \"\"\"Return the length of the longest row in `tensor`.\"\"\"\n",
    "    return max(map(len, tensor))\n",
    "\n",
    "# Longest row on each side: Chinese input first, English output second.\n",
    "max_length_input, max_length_output = (\n",
    "    max_length(padded) for padded in (input_tensor, output_tensor))\n",
    "print(f'{max_length_input} {max_length_output}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(17296, 4325, 17296, 4325)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 训练集和验证集切分\n",
    "from sklearn.model_selection import train_test_split\n",
    "# Hold out 20% of the pairs for evaluation. The seed is fixed so that the\n",
    "# partition is identical on every Restart & Run All; without random_state the\n",
    "# split (and every sample drawn from it) changes from run to run.\n",
    "input_train, input_eval, output_train, output_eval = train_test_split(\n",
    "    input_tensor, output_tensor, test_size=0.2, random_state=42)\n",
    "\n",
    "len(input_train), len(input_eval), len(output_train), len(output_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 --> <start>\n",
      "8 --> 他\n",
      "738 --> 經常\n",
      "4851 --> 引用\n",
      "4162 --> 莎士\n",
      "4163 --> 比亞\n",
      "3 --> 。\n",
      "2 --> <end>\n",
      "\n",
      "1 --> <start>\n",
      "11 --> he\n",
      "221 --> often\n",
      "3067 --> quotes\n",
      "66 --> from\n",
      "2749 --> shakespeare\n",
      "3 --> .\n",
      "2 --> <end>\n"
     ]
    }
   ],
   "source": [
    "# 验证tokenizer是否转化正确\n",
    "def convert(example, tokenizer):\n",
    "    \"\"\"Print one 'id --> word' line for every non-zero id in `example`.\n",
    "\n",
    "    Zero entries are skipped; words are looked up in tokenizer.index_word.\n",
    "    \"\"\"\n",
    "    for token_id in example:\n",
    "        if token_id == 0:\n",
    "            continue\n",
    "        word = tokenizer.index_word[token_id]\n",
    "        print('%d --> %s' % (token_id, word))\n",
    "            \n",
    "# Decode the first training pair: Chinese input, a blank line, then the English output.\n",
    "convert(example=input_train[0], tokenizer=input_tokenizer)\n",
    "print()\n",
    "convert(example=output_train[0], tokenizer=output_tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_dataset(input_tensor, output_tensor, batch_size, epochs, shuffle,\n",
    "                 shuffle_buffer_size=10000):\n",
    "    \"\"\"Build a batched tf.data.Dataset of (input, output) pairs.\n",
    "\n",
    "    Args:\n",
    "        input_tensor: input sequences, sliced along the first axis.\n",
    "        output_tensor: output sequences, sliced in step with `input_tensor`.\n",
    "        batch_size: number of pairs per batch; an incomplete last batch is dropped.\n",
    "        epochs: how many times the data is repeated.\n",
    "        shuffle: whether to shuffle the pairs before repeating.\n",
    "        shuffle_buffer_size: size of the shuffle buffer. Defaults to 10000, the\n",
    "            value that was previously hard-coded. Per the tf.data documentation\n",
    "            the shuffle is only perfect when the buffer is at least as large as\n",
    "            the dataset, so pass len(input_tensor) for a full shuffle.\n",
    "\n",
    "    Returns:\n",
    "        The resulting tf.data.Dataset.\n",
    "    \"\"\"\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((input_tensor, output_tensor))\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(shuffle_buffer_size)\n",
    "    # Repeat first, then batch: batches are cut from the repeated stream.\n",
    "    return dataset.repeat(epochs).batch(batch_size, drop_remainder=True)\n",
    "\n",
    "# Batching / repetition settings for training.\n",
    "batch_size = 32\n",
    "epochs = 10\n",
    "\n",
    "# Training data is shuffled and repeated `epochs` times; evaluation data is\n",
    "# read once, in order.\n",
    "train_dataset = make_dataset(\n",
    "    input_train, output_train, batch_size=batch_size, epochs=epochs, shuffle=True)\n",
    "eval_dataset = make_dataset(\n",
    "    input_eval, output_eval, batch_size=batch_size, epochs=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(32, 33)\n",
      "(32, 36)\n",
      "tf.Tensor(\n",
      "[[  1  24  40 ...   0   0   0]\n",
      " [  1   4  42 ...   0   0   0]\n",
      " [  1  14 104 ...   0   0   0]\n",
      " ...\n",
      " [  1   4 333 ...   0   0   0]\n",
      " [  1   7  18 ...   0   0   0]\n",
      " [  1  12  73 ...   0   0   0]], shape=(32, 33), dtype=int32)\n",
      "tf.Tensor(\n",
      "[[   1   12  168 ...    0    0    0]\n",
      " [   1  103  594 ...    0    0    0]\n",
      " [   1   12  135 ...    0    0    0]\n",
      " ...\n",
      " [   1    5  468 ...    0    0    0]\n",
      " [   1    7   28 ...    0    0    0]\n",
      " [   1 1792  151 ...    0    0    0]], shape=(32, 36), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# Peek at one training batch: both shapes first, then the two id tensors.\n",
    "for x, y in train_dataset.take(1):\n",
    "    print(x.shape, y.shape, x, y, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
